{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Saving and Loading Models\n",
    "\n",
    "In this bite-sized notebook, we'll go over how to save and load models. In general, the process is the same as for any PyTorch module."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "import torch\n",
    "import gpytorch\n",
    "from matplotlib import pyplot as plt"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Saving a Simple Model\n",
    "\n",
    "First, we define a GP Model that we'd like to save. The model used below is the same as the model from our\n",
    "<a href=\"../01_Exact_GPs/Simple_GP_Regression.ipynb\">Simple GP Regression</a> tutorial."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_x = torch.linspace(0, 1, 100)\n",
    "train_y = torch.sin(train_x * (2 * math.pi)) + torch.randn(train_x.size()) * 0.2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# We will use the simplest form of GP model, exact inference\n",
    "class ExactGPModel(gpytorch.models.ExactGP):\n",
    "    def __init__(self, train_x, train_y, likelihood):\n",
    "        super(ExactGPModel, self).__init__(train_x, train_y, likelihood)\n",
    "        self.mean_module = gpytorch.means.ConstantMean()\n",
    "        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())\n",
    "    \n",
    "    def forward(self, x):\n",
    "        mean_x = self.mean_module(x)\n",
    "        covar_x = self.covar_module(x)\n",
    "        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)\n",
    "\n",
    "# initialize likelihood and model\n",
    "likelihood = gpytorch.likelihoods.GaussianLikelihood()\n",
    "model = ExactGPModel(train_x, train_y, likelihood)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Change Model State\n",
    "\n",
    "To demonstrate model saving, we change the hyperparameters from the default values below. For more information on what is happening here, see our tutorial notebook on <a href=\"Hyperparameters.ipynb\">Initializing Hyperparameters</a>."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.covar_module.outputscale = 1.2\n",
    "model.covar_module.base_kernel.lengthscale = 2.2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Getting Model State\n",
    "\n",
    "To get the full state of a GPyTorch model, simply call `state_dict` as you would on any PyTorch model. Note that the state dict contains **raw** parameter values, because these are the actual `torch.nn.Parameter` objects that are learned in GPyTorch. See our notebook on hyperparameters for more information on this."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "OrderedDict([('likelihood.noise_covar.raw_noise', tensor([0.])),\n",
       "             ('mean_module.constant', tensor([0.])),\n",
       "             ('covar_module.raw_outputscale', tensor(0.8416)),\n",
       "             ('covar_module.base_kernel.raw_lengthscale', tensor([[2.0826]]))])"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.state_dict()"
   ]
  },
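  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The raw values in the state dict can be related back to the hyperparameters set earlier. The following is a minimal sketch, assuming the kernel hyperparameters use GPyTorch's default `Positive` constraint with a softplus transform (the printed values are consistent with this assumption). The tensors are copied by hand from the printed output rather than read from the model.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "\n",
    "# Raw values copied by hand from the state dict printed above\n",
    "raw_outputscale = torch.tensor(0.8416)\n",
    "raw_lengthscale = torch.tensor([[2.0826]])\n",
    "\n",
    "# Assumption: the constraint maps raw -> actual values with softplus\n",
    "outputscale = F.softplus(raw_outputscale)\n",
    "lengthscale = F.softplus(raw_lengthscale)\n",
    "\n",
    "print(round(outputscale.item(), 2))  # approximately 1.2\n",
    "print(round(lengthscale.item(), 2))  # approximately 2.2\n",
    "```\n",
    "\n",
    "In practice, the constrained values can be read directly from `model.covar_module.outputscale` and `model.covar_module.base_kernel.lengthscale`."
   ]
  },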
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Saving Model State\n",
    "\n",
    "The state dictionary above contains all trainable parameters of the model. Therefore, we can save it to a file as follows:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.save(model.state_dict(), 'model_state.pth')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Loading Model State\n",
    "\n",
    "Next, we load this state into a new model and demonstrate that the parameters were updated correctly."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "IncompatibleKeys(missing_keys=[], unexpected_keys=[])"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "state_dict = torch.load('model_state.pth')\n",
    "model = ExactGPModel(train_x, train_y, likelihood)  # Create a new GP model\n",
    "\n",
    "model.load_state_dict(state_dict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "OrderedDict([('likelihood.noise_covar.raw_noise', tensor([0.])),\n",
       "             ('mean_module.constant', tensor([0.])),\n",
       "             ('covar_module.raw_outputscale', tensor(0.8416)),\n",
       "             ('covar_module.base_kernel.raw_lengthscale', tensor([[2.0826]]))])"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.state_dict()"
   ]
  },
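  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Comparing two printed state dicts by eye works for a model this small. One way to make the check programmatic is sketched here. `state_dicts_match` is a hypothetical helper, not part of PyTorch or GPyTorch, and a plain `torch.nn.Linear` with an in-memory buffer stands in for the GP model and the file on disk so that the sketch runs without GPyTorch.\n",
    "\n",
    "```python\n",
    "import io\n",
    "import torch\n",
    "\n",
    "def state_dicts_match(a, b):\n",
    "    # Hypothetical helper: same keys and exactly equal tensors\n",
    "    return set(a) == set(b) and all(torch.equal(a[k], b[k]) for k in a)\n",
    "\n",
    "# Stand-in module; the same steps apply to a GPyTorch model\n",
    "source = torch.nn.Linear(1, 2)\n",
    "buffer = io.BytesIO()  # in-memory file instead of a path on disk\n",
    "torch.save(source.state_dict(), buffer)\n",
    "buffer.seek(0)\n",
    "\n",
    "target = torch.nn.Linear(1, 2)  # freshly initialized module\n",
    "target.load_state_dict(torch.load(buffer))\n",
    "\n",
    "matches = state_dicts_match(source.state_dict(), target.state_dict())\n",
    "print(matches)\n",
    "```\n",
    "\n",
    "`torch.equal` requires identical shapes and values, which is the right test for a save and load round trip because no arithmetic is applied to the tensors in between."
   ]
  },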
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## A More Complex Example\n",
    "\n",
    "Next, we demonstrate the same principle on a more complex exact GP that includes a simple feed-forward neural network feature extractor as part of the model.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "class GPWithNNFeatureExtractor(gpytorch.models.ExactGP):\n",
    "    def __init__(self, train_x, train_y, likelihood):\n",
    "        super(GPWithNNFeatureExtractor, self).__init__(train_x, train_y, likelihood)\n",
    "        self.mean_module = gpytorch.means.ConstantMean()\n",
    "        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())\n",
    "        \n",
    "        self.feature_extractor = torch.nn.Sequential(\n",
    "            torch.nn.Linear(1, 2),\n",
    "            torch.nn.BatchNorm1d(2),\n",
    "            torch.nn.ReLU(),\n",
    "            torch.nn.Linear(2, 2),\n",
    "            torch.nn.BatchNorm1d(2),\n",
    "            torch.nn.ReLU(),\n",
    "        )\n",
    "    \n",
    "    def forward(self, x):\n",
    "        x = self.feature_extractor(x)\n",
    "        mean_x = self.mean_module(x)\n",
    "        covar_x = self.covar_module(x)\n",
    "        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)\n",
    "\n",
    "# initialize likelihood and model\n",
    "likelihood = gpytorch.likelihoods.GaussianLikelihood()\n",
    "model = GPWithNNFeatureExtractor(train_x, train_y, likelihood)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Getting Model State\n",
    "\n",
    "In the next cell, we once again print the model state via `model.state_dict()`. As you can see, the state is substantially more complex, as the model now includes our neural network parameters. Nevertheless, saving and loading is straightforward."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "OrderedDict([('likelihood.noise_covar.raw_noise', tensor([0.])),\n",
       "             ('mean_module.constant', tensor([0.])),\n",
       "             ('covar_module.raw_outputscale', tensor(0.)),\n",
       "             ('covar_module.base_kernel.raw_lengthscale', tensor([[0.]])),\n",
       "             ('feature_extractor.0.weight', tensor([[-0.9135],\n",
       "                      [-0.5942]])),\n",
       "             ('feature_extractor.0.bias', tensor([ 0.9119, -0.0663])),\n",
       "             ('feature_extractor.1.weight', tensor([0.2263, 0.2209])),\n",
       "             ('feature_extractor.1.bias', tensor([0., 0.])),\n",
       "             ('feature_extractor.1.running_mean', tensor([0., 0.])),\n",
       "             ('feature_extractor.1.running_var', tensor([1., 1.])),\n",
       "             ('feature_extractor.1.num_batches_tracked', tensor(0)),\n",
       "             ('feature_extractor.3.weight', tensor([[-0.6375, -0.6466],\n",
       "                      [-0.0563, -0.4695]])),\n",
       "             ('feature_extractor.3.bias', tensor([-0.1247,  0.0803])),\n",
       "             ('feature_extractor.4.weight', tensor([0.0466, 0.7248])),\n",
       "             ('feature_extractor.4.bias', tensor([0., 0.])),\n",
       "             ('feature_extractor.4.running_mean', tensor([0., 0.])),\n",
       "             ('feature_extractor.4.running_var', tensor([1., 1.])),\n",
       "             ('feature_extractor.4.num_batches_tracked', tensor(0))])"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.state_dict()"
   ]
  },
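  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Some entries in the printed state, such as `running_mean` and `num_batches_tracked`, are not trainable parameters but buffers registered by the batch-norm layers. The sketch below rebuilds only the feature extractor (without the GP, so that it depends on PyTorch alone) and separates the two kinds of entries.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "# Same layer stack as the feature extractor defined above\n",
    "feature_extractor = torch.nn.Sequential(\n",
    "    torch.nn.Linear(1, 2),\n",
    "    torch.nn.BatchNorm1d(2),\n",
    "    torch.nn.ReLU(),\n",
    "    torch.nn.Linear(2, 2),\n",
    "    torch.nn.BatchNorm1d(2),\n",
    "    torch.nn.ReLU(),\n",
    ")\n",
    "\n",
    "# Parameters are updated by the optimizer; buffers are tracked state\n",
    "param_names = {name for name, _ in feature_extractor.named_parameters()}\n",
    "buffer_names = {name for name, _ in feature_extractor.named_buffers()}\n",
    "state_keys = set(feature_extractor.state_dict().keys())\n",
    "\n",
    "print(sorted(buffer_names))\n",
    "print(state_keys == param_names | buffer_names)\n",
    "```\n",
    "\n",
    "Because `state_dict()` covers both sets, saving it also preserves the batch-norm running statistics."
   ]
  },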
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "IncompatibleKeys(missing_keys=[], unexpected_keys=[])"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.save(model.state_dict(), 'my_gp_with_nn_model.pth')\n",
    "state_dict = torch.load('my_gp_with_nn_model.pth')\n",
    "model = GPWithNNFeatureExtractor(train_x, train_y, likelihood)\n",
    "model.load_state_dict(state_dict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "OrderedDict([('likelihood.noise_covar.raw_noise', tensor([0.])),\n",
       "             ('mean_module.constant', tensor([0.])),\n",
       "             ('covar_module.raw_outputscale', tensor(0.)),\n",
       "             ('covar_module.base_kernel.raw_lengthscale', tensor([[0.]])),\n",
       "             ('feature_extractor.0.weight', tensor([[-0.9135],\n",
       "                      [-0.5942]])),\n",
       "             ('feature_extractor.0.bias', tensor([ 0.9119, -0.0663])),\n",
       "             ('feature_extractor.1.weight', tensor([0.2263, 0.2209])),\n",
       "             ('feature_extractor.1.bias', tensor([0., 0.])),\n",
       "             ('feature_extractor.1.running_mean', tensor([0., 0.])),\n",
       "             ('feature_extractor.1.running_var', tensor([1., 1.])),\n",
       "             ('feature_extractor.1.num_batches_tracked', tensor(0)),\n",
       "             ('feature_extractor.3.weight', tensor([[-0.6375, -0.6466],\n",
       "                      [-0.0563, -0.4695]])),\n",
       "             ('feature_extractor.3.bias', tensor([-0.1247,  0.0803])),\n",
       "             ('feature_extractor.4.weight', tensor([0.0466, 0.7248])),\n",
       "             ('feature_extractor.4.bias', tensor([0., 0.])),\n",
       "             ('feature_extractor.4.running_mean', tensor([0., 0.])),\n",
       "             ('feature_extractor.4.running_var', tensor([1., 1.])),\n",
       "             ('feature_extractor.4.num_batches_tracked', tensor(0))])"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.state_dict()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
